import os
from datasets import load_dataset
from datasets.features.features import require_decoding
from datasets import config
from datasets.utils.py_utils import convert_file_size_to_int
from datasets.table import embed_table_storage
from tqdm import tqdm
import torch


# --- Configuration -------------------------------------------------------
data_dir = 'lj_speech_parquets'   # output directory for the parquet shards
split = 'train'                   # LJ Speech only ships a "train" split
max_shard_size = '500MB'          # target upper bound per parquet shard

# Passing `split=` makes load_dataset return a Dataset rather than a
# DatasetDict — the attribute accesses below (.features, ._estimate_nbytes,
# .shard) are only valid on a Dataset.
dataset = load_dataset("lj_speech", split=split)

# 80/10/10 split. The last size is computed by subtraction so the three
# sizes always sum to len(dataset); otherwise int() truncation can make
# torch.utils.data.random_split raise a ValueError.
train_size = int(len(dataset) * 0.8)
val_size = int(len(dataset) * 0.1)
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])

# Columns (e.g. Audio) that hold references to external files which must be
# decoded/embedded into the table before exporting to parquet.
decodable_columns = (
    [k for k, v in dataset.features.items() if require_decoding(v, ignore_decode_attribute=True)]
)

# Estimate the on-disk size and derive how many shards are needed so each
# one stays under max_shard_size.
dataset_nbytes = dataset._estimate_nbytes()
max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE)
num_shards = int(dataset_nbytes / max_shard_size) + 1
num_shards = max(num_shards, 1)

# Lazily yield contiguous slices of the full dataset, one per shard.
shards = (dataset.shard(num_shards=num_shards, index=i, contiguous=True) for i in range(num_shards))

def shards_with_embedded_external_files(shards):
    """Yield each shard with its external files embedded into the table.

    For every shard, the table storage is rewritten (via
    ``embed_table_storage``) so that externally referenced data — e.g.
    audio files — is stored inline, making the resulting parquet files
    self-contained. The shard's original format is restored before
    yielding.

    Args:
        shards: iterable of ``datasets.Dataset`` shards.

    Yields:
        ``datasets.Dataset``: the shard with embedded storage, in its
        original format.
    """
    for shard in shards:
        # Remember the caller-visible format; `original_format` avoids
        # shadowing the builtin `format`.
        original_format = shard.format
        # Work in "arrow" format so map() receives raw Arrow tables.
        shard = shard.with_format("arrow")
        shard = shard.map(
            embed_table_storage,
            batched=True,
            batch_size=1000,
            keep_in_memory=True,
        )
        # Restore the format the shard had on entry.
        shard = shard.with_format(**original_format)
        yield shard
# Wrap the shard generator so external audio files get embedded on the fly.
shards = shards_with_embedded_external_files(shards)

# exist_ok=True lets the script be re-run without failing with
# FileExistsError when the output directory is already present.
os.makedirs(data_dir, exist_ok=True)

# Write each shard as a zero-padded `NNNNN-of-NNNNN.parquet` file.
for index, shard in tqdm(
    enumerate(shards),
    desc="Save the dataset shards",
    total=num_shards,
):
    shard_path = f"{data_dir}/{index:05d}-of-{num_shards:05d}.parquet"
    shard.to_parquet(shard_path)